{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e1qfiZMOJYYv"
      },
      "source": [
        "# Graphbolt Quick Walkthrough\n",
        "\n",
        "This tutorial gives a quick walkthrough of the operators in the `dgl.graphbolt` package and shows how to build a GNN datapipe with them. To learn more about stochastic training of GNNs, see the [materials](https://docs.dgl.ai/tutorials/large/index.html) provided by DGL.\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fWiaC1WaDE-W"
      },
      "outputs": [],
      "source": [
        "# Install required packages.\n",
        "import os\n",
        "import torch\n",
        "os.environ['TORCH'] = torch.__version__\n",
        "os.environ['DGLBACKEND'] = \"pytorch\"\n",
        "\n",
        "# Install the CPU version.\n",
        "device = torch.device(\"cpu\")\n",
        "!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\n",
        "\n",
        "try:\n",
        "    import dgl.graphbolt as gb\n",
        "    installed = True\n",
        "except ImportError as error:\n",
        "    installed = False\n",
        "    print(error)\n",
        "print(\"DGL installed!\" if installed else \"DGL not found!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8O7PfsY4sPoN"
      },
      "source": [
        "## Dataset\n",
        "\n",
        "The dataset has three primary components: (1) an item set, which is iterated over as the training target; (2) a sampling graph, which the subgraph sampling algorithm uses to generate a subgraph; and (3) a feature store, which stores node, edge, and graph features.\n",
        "\n",
        "* The **ItemSet** is created from iterable data or a tuple of iterable data. In the example below, each item is a pair of node IDs (a seed edge)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g73ZAbMQsSgV"
      },
      "outputs": [],
      "source": [
        "seeds = torch.tensor(\n",
        "    [[7, 0], [6, 0], [1, 3], [3, 3], [2, 4], [8, 4], [1, 4], [2, 4], [1, 5],\n",
        "     [9, 6], [0, 6], [8, 6], [7, 7], [7, 7], [4, 7], [6, 8], [5, 8], [9, 9],\n",
        "     [4, 9], [4, 9], [5, 9], [9, 9], [5, 9], [9, 9], [7, 9]]\n",
        ")\n",
        "item_set = gb.ItemSet(seeds, names=\"seeds\")\n",
        "print(list(item_set))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lqty9p4cs0OR"
      },
      "source": [
        "* The **SamplingGraph** is used by the subgraph sampling algorithm to generate a subgraph. Graphbolt provides a canonical implementation, the FusedCSCSamplingGraph, which achieves state-of-the-art time and space efficiency for CPU sampling. However, it requires enough CPU memory to hold all FusedCSCSamplingGraph objects."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jDjY149xs3PI"
      },
      "outputs": [],
      "source": [
        "indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\n",
        "indices = torch.tensor(\n",
        "    [7, 6, 1, 3, 2, 8, 1, 2, 1, 9, 0, 8, 7, 7, 4, 6, 5, 9, 4, 4, 5, 9, 5, 9, 7]\n",
        ")\n",
        "num_edges = 25\n",
        "eid = torch.arange(num_edges)\n",
        "edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\n",
        "graph = gb.fused_csc_sampling_graph(indptr, indices, edge_attributes=edge_attributes)\n",
        "print(graph)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mNp2S2_Vs8af"
      },
      "source": [
        "* The **FeatureStore** is used to store node, edge, and graph features. In graphbolt, we provide the TorchBasedFeature and related optimizations, such as the GPUCachedFeature, for different use cases."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zIU6KWe1Sm2g"
      },
      "outputs": [],
      "source": [
        "num_nodes = 10\n",
        "num_edges = 25\n",
        "node_feature_data = torch.rand((num_nodes, 2))\n",
        "edge_feature_data = torch.rand((num_edges, 3))\n",
        "node_feature = gb.TorchBasedFeature(node_feature_data)\n",
        "edge_feature = gb.TorchBasedFeature(edge_feature_data)\n",
        "features = {\n",
        "    (\"node\", None, \"feat\") : node_feature,\n",
        "    (\"edge\", None, \"feat\") : edge_feature,\n",
        "}\n",
        "feature_store = gb.BasicFeatureStore(features)\n",
        "print(feature_store)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Oh2ockWWoXQ0"
      },
      "source": [
        "## DataPipe\n",
        "\n",
        "The DataPipe in Graphbolt extends the PyTorch DataPipe and is designed specifically to address the challenges of training graph neural networks (GNNs). Each stage of the data pipeline loads data from a different source and can be combined with other stages to build more complex pipelines. The intermediate data is stored in **MiniBatch** data packs.\n",
        "\n",
        "* **ItemSampler** iterates over the input **ItemSet** and creates subsets (mini-batches)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XtqPDprrogR7"
      },
      "outputs": [],
      "source": [
        "datapipe = gb.ItemSampler(item_set, batch_size=3, shuffle=False)\n",
        "print(next(iter(datapipe)))"
      ]
    },
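    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BjkAK37xopp1"
      },
      "source": [
        "* **NegativeSampler** generates negative samples and returns a mix of positive and negative samples. Here, the second argument sets the negative ratio to 1, so one negative sample is drawn per positive sample."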
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PrFpGoOGopJy"
      },
      "outputs": [],
      "source": [
        "datapipe = datapipe.sample_uniform_negative(graph, 1)\n",
        "print(next(iter(datapipe)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fYO_oIwkpmb3"
      },
      "source": [
        "* **SubgraphSampler** samples a subgraph of a larger graph, starting from a given set of nodes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4UsY3PL3ppYV"
      },
      "outputs": [],
      "source": [
        "fanouts = torch.tensor([1])\n",
        "datapipe = datapipe.sample_neighbor(graph, [fanouts])\n",
        "print(next(iter(datapipe)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0uIydsjUqMA0"
      },
      "source": [
        "* **FeatureFetcher** fetches node and edge features from the feature store."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YAj8G7YBqO6G"
      },
      "outputs": [],
      "source": [
        "datapipe = datapipe.fetch_feature(feature_store, node_feature_keys=[\"feat\"], edge_feature_keys=[\"feat\"])\n",
        "print(next(iter(datapipe)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hjBSLPRPrsD2"
      },
      "source": [
        "* Copy the data to the target device (for example, a GPU) for training. In this walkthrough, `device` was set to the CPU in the installation cell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RofiZOUMqt_u"
      },
      "outputs": [],
      "source": [
        "datapipe = datapipe.copy_to(device=device)\n",
        "print(next(iter(datapipe)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xm9HnyHRvxXj"
      },
      "source": [
        "## Exercise: Node classification\n",
        "\n",
        "Similarly, the following dataset is created for node classification. Can you implement the data pipeline for it?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YV-mk-xAv78v"
      },
      "outputs": [],
      "source": [
        "# Dataset for node classification.\n",
        "num_nodes = 10\n",
        "nodes = torch.arange(num_nodes)\n",
        "labels = torch.tensor([1, 2, 0, 2, 2, 0, 2, 2, 2, 2])\n",
        "item_set = gb.ItemSet((nodes, labels), names=(\"seeds\", \"labels\"))\n",
        "\n",
        "indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\n",
        "indices = torch.tensor(\n",
        "    [7, 6, 1, 3, 2, 8, 1, 2, 1, 9, 0, 8, 7, 7, 4, 6, 5, 9, 4, 4, 5, 9, 5, 9, 7]\n",
        ")\n",
        "eid = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n",
        "                    14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])\n",
        "edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\n",
        "graph = gb.fused_csc_sampling_graph(indptr, indices, edge_attributes=edge_attributes)\n",
        "\n",
        "num_nodes = 10\n",
        "num_edges = 25\n",
        "node_feature_data = torch.rand((num_nodes, 2))\n",
        "edge_feature_data = torch.rand((num_edges, 3))\n",
        "node_feature = gb.TorchBasedFeature(node_feature_data)\n",
        "edge_feature = gb.TorchBasedFeature(edge_feature_data)\n",
        "features = {\n",
        "    (\"node\", None, \"feat\") : node_feature,\n",
        "    (\"edge\", None, \"feat\") : edge_feature,\n",
        "}\n",
        "feature_store = gb.BasicFeatureStore(features)\n",
        "\n",
        "# Datapipe.\n",
        "...\n",
        "print(next(iter(datapipe)))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "BjkAK37xopp1"
      ],
      "gpuType": "T4",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
